#include "trainthread.h"
#include "mainwindow.h"
#include "surface.h"
#include "elm.h"

/**
 * Builds a training thread bound to the given window, surface and network.
 *
 * The three pointers are stored as-is. The classifier count is read from the
 * window once, and the RGB value of every classifier colour is copied into a
 * freshly allocated mColors array so later lookups do not go through the
 * window.
 *
 * @param window   source of the classifier count and classifier colours
 * @param surface  stored in mSurface
 * @param network  stored in mNetwork
 */
TrainThread::TrainThread(MainWindow *window, Surface *surface, ELM *network)
    : mWindow(window)
    , mSurface(surface)
    , mNetwork(network)
{
    mClassCnt = mWindow->getClassifierCnt();

    // One colour slot per classifier.
    // NOTE(review): allocated with new[]; confirm it is released in the destructor.
    mColors = new QRgb[static_cast<uint>(mClassCnt)];

    uint cls = 0;
    while (cls < mClassCnt)
    {
        const int classifier = static_cast<int>(cls);
        mColors[cls] = mWindow->getClassifierColor(classifier).rgb();
        ++cls;
    }
}

void TrainThread::run()
{

    mNetwork->randomWeightAndBias();
    bool trainOK=mNetwork->train();
    if(trainOK){
        int height=mSurface->height();
        int width=mSurface->width();
        height=height/4*4;
        width=width/4*4;
        QImage image(width,height,QImage::Format_RGB888);
        uchar* data=image.bits();
        for(int i=0;i<width;i++)
            for(int j=0;j<height;j++)
            {
                double input[2];
                input[0]=i*1.0/width;
                input[1]=1.0-j*1.0/height;
                mNetwork->test(input);
                double* output=mNetwork->testOutput();
                QRgb color=getColor(output);
                //image.setPixelColor(i,j,QColor(color));
                setData(data,i,j,width,
                        static_cast<uchar>(qRed(color)),
                        static_cast<uchar>(qGreen(color)),
                        static_cast<uchar>(qBlue(color)));
            }
        mSurface->setShowMap(true);
        emit sigMap(QPixmap::fromImage(image));

    }
      emit sigTrainFinished();
}
/**
 * Maps a network output vector to the colour of its highest-scoring class.
 *
 * @param output  per-class scores; findMax() reads mClassCnt entries from it
 * @return the cached colour of the winning class, or opaque black when
 *         findMax() reports no winner (index -1)
 */
QRgb TrainThread::getColor(double* output)
{
    int idx=-1;
    findMax(output,&idx);

    // findMax() yields -1 when it selects nothing (e.g. no classes);
    // indexing mColors with it would read outside the array.
    if(idx<0)
    {
        return qRgb(0,0,0);
    }
    return mColors[idx];
}
double TrainThread::findMax(double *d, int *i)
{
    int idx=-1;
    double value=-1;
    for(uint i=0;i<mClassCnt;i++)
    {
        if(d[i]>value)
        {
            value=d[i];
            idx=static_cast<int>(i);
        }
    }
    *i=idx;
    return value;
}
/**
 * Writes one RGB pixel into a packed 3-bytes-per-pixel buffer.
 *
 * The pixel is located at byte offset y*width*3 + x*3, i.e. rows are taken
 * to be exactly width*3 bytes long with no padding.
 *
 * @param data   start of the pixel buffer
 * @param x      pixel column
 * @param y      pixel row
 * @param width  number of pixels per row
 * @param r      red byte, stored first
 * @param g      green byte, stored second
 * @param b      blue byte, stored third
 */
void TrainThread::setData(uchar *data, int x, int y, int width,uchar r,uchar g,uchar b)
{
    uchar *pixel=data+(y*width*3)+x*3;
    pixel[0]=r;
    pixel[1]=g;
    pixel[2]=b;
}
